#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2016-2018, Niklas Hauser
# Copyright (c) 2017, Fabian Greif
# Copyright (c) 2020, Mike Wolfram
# Copyright (c) 2021, Raphael Lehmann
# Copyright (c) 2021-2023, Christopher Durand
#
# This file is part of the modm project.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# -----------------------------------------------------------------------------

from collections import defaultdict
import re

def init(module):
    """Set the name and the short description of this module on `module`."""
    module.description = "Direct Memory Access (DMA)"
    module.name = ":platform:dma"


def prepare(module, options):
    """
    Declare the module dependencies and report whether the target is supported.

    :param module: module object; its `depends()` method is always called,
                   even when the target turns out to be unsupported.
    :param options: mapping that provides the target device under ":target".
    :returns: True if the device has a "dma" driver of a supported type,
              False otherwise.
    """
    supported_types = (
        "stm32-channel-request",
        "stm32-channel",
        "stm32-mux",
        "stm32-mux-stream",
        "stm32-stream-channel",
    )
    device = options[":target"]

    module.depends(":cmsis:device", ":platform:rcc")

    # Devices without any DMA driver are rejected before the driver is queried.
    if not device.has_driver("dma"):
        return False
    return device.get_driver("dma")["type"] in supported_types


def get_irq_list(device):
    """
    Parse the DMA interrupt vectors of `device` into instances and channels.

    The vector names are read from the "core" driver. Only one name per vector
    position is considered (the last one listed for that position), and only
    names that start with `DMA<n>_`, optionally preceded by a `<prefix>_`.

    :param device: object providing `get_driver("core")["vector"]`, a sequence
                   of entries with "position" and "name".
    :returns: a list of `{"name": <irq name>, "instances": [...]}` entries,
              where every instance is `{"instance": <int>, "channels": [<int>, ...]}`.
              A name with exactly two channel numbers (e.g. `Ch4_7`) is
              expanded to the inclusive range between them.
    :raises ValueError: if a DMA interrupt name contains no channel or stream
                        description that can be parsed.
    """
    dma_irq_pattern = re.compile(r"(?:\w+_)?DMA\d_")
    group_pattern = re.compile(r"(DMA\d_(Ch|CH|Channel|Stream)\d(_\d)*)")
    detail_pattern = re.compile(r"DMA(?P<instance>\d)_(Ch|CH|Channel|Stream)(?P<channels>(\d(_\d)*))")

    # Keyed by position, so a later entry replaces the name of an earlier one
    # at the same position.
    name_by_position = {}
    for vector in device.get_driver("core")["vector"]:
        name_by_position[vector["position"]] = vector["name"]

    result = []
    for name in name_by_position.values():
        if dma_irq_pattern.match(name) is None:
            continue

        # One group per DMA instance served by this interrupt,
        # e.g. [DMA1_Ch4_7, DMA2_Ch1_5] or [DMA1_Stream0]
        groups = [found.group(0) for found in group_pattern.finditer(name)]
        if not groups:
            raise ValueError("Failed to parse irq '{}'".format(name))

        parsed = []
        for group in groups:
            detail = detail_pattern.match(group)
            if detail is None:
                raise ValueError("Failed to parse irq channels: '{}'".format(group))
            numbers = [int(number) for number in detail.group("channels").split("_")]
            if len(numbers) == 2:
                # Two numbers denote the first and the last channel of a range.
                first, last = numbers
                numbers = list(range(first, last + 1))
            parsed.append({"instance": int(detail.group("instance")), "channels": numbers})
        result.append({"name": name, "instances": parsed})
    return result


def build(env):
    """
    Gather the DMA description of the target and render the driver templates.

    Depending on the DMA architecture given by `dma["type"]`, this collects the
    capitalized names of all DMA signals and the channel range of every DMA
    controller, removes duplicated peripheral signals that the generated API
    cannot express, maps the DMA interrupt vectors to the channels they serve
    and passes everything to the templates as substitutions.

    The property dictionary of the device and the DMA driver data are modified
    in place.

    :param env: build environment; `env[":target"]` must provide `properties`,
                `identifier` and `get_driver()`.
    :raises NotImplementedError: if the DMA driver type is not handled here.
    """
    device = env[":target"]

    properties = device.properties
    properties["target"] = device.identifier
    dma = device.get_driver("dma")
    properties["dma"] = dma

    if dma["type"] in ["stm32-stream-channel"]:
        # Make the streams available under the "channels" key as well.
        dma["channels"] = dma["streams"]

    signal_names = set()

    def collect_signal_names(signals):
        # Signals without a name do not contribute to the list of signal names.
        signal_names.update(signal["name"].capitalize() for signal in signals if "name" in signal)

    controller = []
    if dma["type"] in ["stm32-channel-request", "stm32-channel"]:
        # Get the peripheral supported by DMA from device info and create a list of signals
        # (also determines the maximum number of channels per controller)
        for channels in dma["channels"]:
            max_channels = 0
            for channel in channels["channel"]:
                # Use the highest position, independent of the order of the entries.
                max_channels = max(max_channels, int(channel["position"]))
                if dma["type"] in ["stm32-channel-request"]:
                    # Some F0 allow to use multiple requests for the same peripheral signal
                    # on the same channel. This is not supported with the current api.
                    # Therefore, signals are deduplicated.
                    seen_signals = set()
                    duplicates = []
                    for request in channel["request"]:
                        for signal in request["signal"]:
                            signal_data = (signal.get("driver"),
                                           signal.get("instance"),
                                           signal.get("name"))
                            if signal_data in seen_signals:
                                duplicates.append((request, signal))
                            else:
                                seen_signals.add(signal_data)
                        collect_signal_names(request["signal"])
                    # Duplicates are removed only after the scan, so that no list
                    # is modified while it is being iterated.
                    for request, signal in duplicates:
                        request["signal"].remove(signal)
                        if len(request["signal"]) == 0:
                            channel["request"].remove(request)
                else:
                    collect_signal_names(channel["signal"])
            controller.append({"instance": int(channels["instance"]),
                               "min_channel": 1, "max_channel": max_channels})
    elif dma["type"] in ["stm32-mux", "stm32-mux-stream"]:
        for request_data in dma["requests"]:
            for request in request_data["request"]:
                collect_signal_names(request["signal"])

        assert len(dma["mux-channels"]) == 1  # only one DMAMUX instance is supported
        instance_channels = defaultdict(list)
        for channel in dma["mux-channels"][0]["mux-channel"]:
            instance_channels[channel["dma-instance"]].append(int(channel["dma-channel"]))
        for instance, channel_list in instance_channels.items():
            controller.append({"instance": int(instance),
                               "min_channel": min(channel_list), "max_channel": max(channel_list)})
    elif dma["type"] in ["stm32-stream-channel"]:
        for channels in dma["streams"]:
            max_channels = 0
            for channel in channels["stream"]:
                # Use the highest position, independent of the order of the entries.
                max_channels = max(max_channels, int(channel["position"]))
                for request in channel["channel"]:
                    collect_signal_names(request["signal"])
            controller.append({"instance": int(channels["instance"]),
                               "min_channel": 0, "max_channel": max_channels})
    else:
        raise NotImplementedError

    did = device.identifier
    if (did.family in ['f0'] and did.name == '30' and did.size == 'c') or \
            (did.family in ['f1'] and did.name not in ['05', '07'] and did.size in ['4', '6', '8', 'b']):
        # FIXME: Bug in modm-devices data: Dma2 does not exist on device.identifier
        # The guard keeps this workaround harmless once the device data is corrected.
        if '2' in properties["dma"]["instance"]:
            properties["dma"]["instance"].remove('2')
        controller = [c for c in controller if c["instance"] != 2]
        # Iterate over a copy: removing from the list that is being iterated
        # would skip the entry following each removed one.
        for channels in list(dma["channels"]):
            if channels["instance"] == '2':
                dma["channels"].remove(channels)

    # With 'stm32-stream-channel' architecture some {peripherals, ~stream~ channel}
    # can be used on multiple ~channels~ requests. This is not supported with the
    # current API. The requests are de-duplicated:
    if dma["type"] in ["stm32-stream-channel"]:
        for channels in dma["streams"]:
            seen_channel_peripherals = set()
            for channel in channels["stream"]:
                for request in channel["channel"]:
                    # Iterate over a copy, because the removal below may modify
                    # this very list.
                    for signal in list(request["signal"]):
                        peripheral = signal["driver"]
                        if "instance" in signal and peripheral not in ["Dac"]:
                            peripheral = peripheral + str(signal["instance"])
                        name = signal["name"] if "name" in signal else ""
                        channel_peripheral = "DMA{}, Channel{}, {}, {}".format(
                            channels["instance"], channel["position"], peripheral, name)
                        if channel_peripheral in seen_channel_peripherals:
                            # The first request of this stream that holds an equal
                            # signal entry loses it, so that one entry remains.
                            for candidate in channel["channel"]:
                                if signal in candidate["signal"]:
                                    candidate["signal"].remove(signal)
                                    break
                        else:
                            seen_channel_peripherals.add(channel_peripheral)

    properties["dmaType"] = dma["type"]
    properties["dmaSignals"] = sorted(signal_names)
    properties["dmaController"] = controller

    properties['channel_count'] = {
        "min": min(c["min_channel"] for c in controller),
        "max": max(c["max_channel"] for c in controller),
    }

    # remove channels not available from irqs
    max_channel_map = {c["instance"]: c["max_channel"] for c in controller}
    irq_list = get_irq_list(device)
    for irq in irq_list:
        for instance in irq["instances"]:
            max_channel = max_channel_map[instance["instance"]]
            instance["channels"] = [ch for ch in instance["channels"] if ch <= max_channel]
    properties["irqList"] = irq_list

    # Look-up table: DMA instance -> channel -> name of the interrupt serving it
    irq_map = defaultdict(dict)
    for irq in irq_list:
        for instance in irq["instances"]:
            for channel in instance["channels"]:
                irq_map[instance["instance"]][channel] = irq["name"]
    properties["irqMap"] = irq_map

    env.substitutions = properties
    env.outbasepath = "modm/src/modm/platform/dma"

    env.template("dma_base.hpp.in")
    env.template("dma_hal.hpp.in")
    env.template("dma_hal_impl.hpp.in")
    env.template("dma.hpp.in")
    env.template("dma.cpp.in")

